{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Defining Custom Tools\n",
    "When constructing your own agent, you need to provide it with a list of tools that it can use. Besides the function that is actually called, a tool consists of several components:\n",
    "\n",
    "- `name` (str): required, and must be unique within the set of tools provided to an agent.\n",
    "- `description` (str): optional but recommended, because the agent uses it to decide when to use the tool.\n",
    "- `args_schema` (Pydantic BaseModel): optional but recommended; it can provide more information (e.g., few-shot examples) or validation for the expected parameters."
   ]
  },
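  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## @tool decorator\n",
    "The `@tool` decorator is the simplest way to define a custom tool. By default it uses the function name as the tool name and builds the description from the function's signature and docstring, so the function needs a docstring. You can override the name by passing a string as the first argument, and you can also pass `args_schema` and `return_direct`."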
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "search\n",
      "search(query: str) -> str - Look up things online.\n",
      "{'query': {'title': 'Query', 'type': 'string'}}\n"
     ]
    }
   ],
   "source": [
    "# Imports shared by the examples in this notebook\n",
    "from langchain.pydantic_v1 import BaseModel, Field\n",
    "from langchain.tools import BaseTool, StructuredTool, tool\n",
    "\n",
    "@tool\n",
    "def search(query: str) -> str:\n",
    "    \"\"\"Look up things online.\"\"\"\n",
    "    return \"LangChain\"\n",
    "\n",
    "print(search.name)\n",
    "print(search.description)\n",
    "print(search.args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "multiply\n",
      "multiply(a: int, b: int) -> int - Multiply two numbers.\n",
      "{'a': {'title': 'A', 'type': 'integer'}, 'b': {'title': 'B', 'type': 'integer'}}\n"
     ]
    }
   ],
   "source": [
    "@tool\n",
    "def multiply(a: int, b: int) -> int:\n",
    "    \"\"\"Multiply two numbers.\"\"\"\n",
    "    return a * b\n",
    "\n",
    "print(multiply.name)\n",
    "print(multiply.description)\n",
    "print(multiply.args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "search-tool\n",
      "search-tool(query: str) -> str - Look up things online.\n",
      "{'query': {'title': 'Query', 'description': 'should be a search query', 'type': 'string'}}\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "class SearchInput(BaseModel):\n",
    "    query: str = Field(description=\"should be a search query\")\n",
    "\n",
    "\n",
    "@tool(\"search-tool\", args_schema=SearchInput, return_direct=True)\n",
    "def search(query: str) -> str:\n",
    "    \"\"\"Look up things online.\"\"\"\n",
    "    return \"LangChain\"\n",
    "\n",
    "print(search.name)\n",
    "print(search.description)\n",
    "print(search.args)\n",
    "print(search.return_direct)"
   ]
  },
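  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Subclass BaseTool\n",
    "Subclassing `BaseTool` gives the most control over the tool definition, at the cost of more code: you implement `_run` for synchronous use and `_arun` for asynchronous use."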
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional, Type\n",
    "\n",
    "from langchain.callbacks.manager import (\n",
    "    AsyncCallbackManagerForToolRun,\n",
    "    CallbackManagerForToolRun,\n",
    ")\n",
    "\n",
    "\n",
    "class SearchInput(BaseModel):\n",
    "    query: str = Field(description=\"should be a search query\")\n",
    "\n",
    "\n",
    "class CalculatorInput(BaseModel):\n",
    "    a: int = Field(description=\"first number\")\n",
    "    b: int = Field(description=\"second number\")\n",
    "\n",
    "\n",
    "class CustomSearchTool(BaseTool):\n",
    "    name: str = \"custom_search\"\n",
    "    description: str = \"useful for when you need to answer questions about current events\"\n",
    "    args_schema: Type[BaseModel] = SearchInput\n",
    "\n",
    "    def _run(\n",
    "        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None\n",
    "    ) -> str:\n",
    "        \"\"\"Use the tool.\"\"\"\n",
    "        return \"LangChain\"\n",
    "\n",
    "    async def _arun(\n",
    "        self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None\n",
    "    ) -> str:\n",
    "        \"\"\"Use the tool asynchronously.\"\"\"\n",
    "        raise NotImplementedError(\"custom_search does not support async\")\n",
    "\n",
    "\n",
    "class CustomCalculatorTool(BaseTool):\n",
    "    name: str = \"Calculator\"\n",
    "    description: str = \"useful for when you need to answer questions about math\"\n",
    "    args_schema: Type[BaseModel] = CalculatorInput\n",
    "    return_direct: bool = True\n",
    "\n",
    "    def _run(\n",
    "        self, a: int, b: int, run_manager: Optional[CallbackManagerForToolRun] = None\n",
    "    ) -> int:\n",
    "        \"\"\"Multiply the two numbers.\"\"\"\n",
    "        return a * b\n",
    "\n",
    "    async def _arun(\n",
    "        self,\n",
    "        a: int,\n",
    "        b: int,\n",
    "        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,\n",
    "    ) -> int:\n",
    "        \"\"\"Use the tool asynchronously.\"\"\"\n",
    "        raise NotImplementedError(\"Calculator does not support async\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "custom_search\n",
      "useful for when you need to answer questions about current events\n",
      "{'query': {'title': 'Query', 'description': 'should be a search query', 'type': 'string'}}\n"
     ]
    }
   ],
   "source": [
    "search = CustomSearchTool()\n",
    "print(search.name)\n",
    "print(search.description)\n",
    "print(search.args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Calculator\n",
      "useful for when you need to answer questions about math\n",
      "{'a': {'title': 'A', 'description': 'first number', 'type': 'integer'}, 'b': {'title': 'B', 'description': 'second number', 'type': 'integer'}}\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "multiply = CustomCalculatorTool()\n",
    "print(multiply.name)\n",
    "print(multiply.description)\n",
    "print(multiply.args)\n",
    "print(multiply.return_direct)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## StructuredTool dataclass\n",
    "You can also use the `StructuredTool` dataclass. This method is a mix of the previous two: it is more convenient than inheriting from the `BaseTool` class, but provides more functionality than just using a decorator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Search\n",
      "Search(query: str) - useful for when you need to answer questions about current events\n",
      "{'query': {'title': 'Query', 'type': 'string'}}\n"
     ]
    }
   ],
   "source": [
    "def search_function(query: str):\n",
    "    return \"LangChain\"\n",
    "\n",
    "\n",
    "search = StructuredTool.from_function(\n",
    "    func=search_function,\n",
    "    name=\"Search\",\n",
    "    description=\"useful for when you need to answer questions about current events\",\n",
    "    # coroutine= ... <- you can specify an async method if desired as well\n",
    ")\n",
    "print(search.name)\n",
    "print(search.description)\n",
    "print(search.args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Calculator\n",
      "Calculator(a: int, b: int) -> int - multiply numbers\n",
      "{'a': {'title': 'A', 'description': 'first number', 'type': 'integer'}, 'b': {'title': 'B', 'description': 'second number', 'type': 'integer'}}\n"
     ]
    }
   ],
   "source": [
    "class CalculatorInput(BaseModel):\n",
    "    a: int = Field(description=\"first number\")\n",
    "    b: int = Field(description=\"second number\")\n",
    "\n",
    "\n",
    "def multiply(a: int, b: int) -> int:\n",
    "    \"\"\"Multiply two numbers.\"\"\"\n",
    "    return a * b\n",
    "\n",
    "\n",
    "calculator = StructuredTool.from_function(\n",
    "    func=multiply,\n",
    "    name=\"Calculator\",\n",
    "    description=\"multiply numbers\",\n",
    "    args_schema=CalculatorInput,\n",
    "    return_direct=True,\n",
    "    # coroutine= ... <- you can specify an async method if desired as well\n",
    ")\n",
    "print(calculator.name)\n",
    "print(calculator.description)\n",
    "print(calculator.args)"
   ]
  },
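  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Handling Tool Errors\n",
    "When a tool raises a `ToolException` and `handle_tool_error` is set, the exception is caught and a message is returned to the agent instead of stopping the run. `handle_tool_error` can be `True`, a string, or a function that takes the exception and returns a string, as in the example below."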
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'The following errors occurred during tool execution:\\nThe search tool1 is not available. Please try another tool.'"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain_core.tools import ToolException\n",
    "\n",
    "\n",
    "def search_tool1(s: str):\n",
    "    raise ToolException(\"The search tool1 is not available.\")\n",
    "\n",
    "\n",
    "def _handle_error(error: ToolException) -> str:\n",
    "    # Build the message that is returned to the agent in place of the exception.\n",
    "    return (\n",
    "        \"The following errors occurred during tool execution:\\n\"\n",
    "        + error.args[0]\n",
    "        + \" Please try another tool.\"\n",
    "    )\n",
    "\n",
    "search = StructuredTool.from_function(\n",
    "    func=search_tool1,\n",
    "    name=\"Search_tool1\",\n",
    "    description=\"A bad tool\",\n",
    "    handle_tool_error=_handle_error,\n",
    ")\n",
    "\n",
    "search.run(\"test\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "langchain0_1",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
